import os
import csv
from p_tqdm import p_map
from argparse import ArgumentParser
from copy import deepcopy

import numpy as np

from learner import SumUCB, GaussTS
from learner import UpUCB, UpUCB_L, UpUCB_Gap
from bandit import GaussianUpliftBandit, interact


# Algorithm names accepted for ``params.algo``.
_KNOWN_ALGOS = ('UCB', 'TS', 'UpUCB', 'UpUCB_L', 'UpUCB_Gap')


def _build_learner(params, bandit, baseline):
    """Instantiate the learner selected by ``params.algo``.

    Args:
        params: Parsed command-line namespace.
        bandit: The bandit instance; ``bandit.rewards`` is read for 'TS'
            and ``bandit.affected_sets`` for 'UpUCB'.
        baseline: Baseline handed to the UpUCB-family learners, or ``None``.

    Returns:
        The constructed learner object.

    Raises:
        ValueError: If ``params.algo`` is not one of ``_KNOWN_ALGOS``.
    """
    algo = params.algo

    if algo == 'UCB':
        return SumUCB(params.n_arms,
                      params.n_variables,
                      params.radius)

    if algo == 'TS':
        # The prior is the empirical mean/variance of the bandit's rewards.
        prior_mean = np.mean(bandit.rewards)
        prior_var = np.var(bandit.rewards)
        return GaussTS(params.n_arms,
                       params.radius,
                       prior_mean,
                       prior_var)

    if algo == 'UpUCB':
        return UpUCB(params.n_arms,
                     params.n_variables,
                     bandit.affected_sets,
                     baseline,
                     params.baseline_option,
                     params.radius)

    if algo == 'UpUCB_L':
        return UpUCB_L(params.n_arms,
                       params.n_variables,
                       params.n_affected_learner,
                       baseline,
                       params.radius)

    if algo == 'UpUCB_Gap':
        return UpUCB_Gap(params.n_arms,
                         params.n_variables,
                         params.minimum_uplift_learner,
                         baseline,
                         params.radius)

    raise ValueError(
        f'unknown algo {algo!r}; expected one of ' + ', '.join(_KNOWN_ALGOS))


def main(params):
    """Run one experiment configuration and save its results to disk.

    The function builds a ``GaussianUpliftBandit`` and the learner named by
    ``params.algo``, executes ``params.n_runs`` runs through ``p_map`` (run
    ``i`` is seeded with ``params.random_seed + i``), and writes into a
    directory derived from the bandit settings:

    * ``<name>-params.csv``: one ``key,value`` row per entry of ``params``.
    * if ``params.save_all``: ``<name>-regrets.npy`` and
      ``<name>-arm_his.npy`` (one row per run);
      otherwise ``<name>-regrets-mean.npy`` and ``<name>-regrets-std.npy``
      (mean / standard deviation across runs).

    Regrets are accumulated along the round axis with ``np.cumsum`` and both
    arrays keep only every ``params.save_every``-th column.

    Args:
        params: Namespace providing ``n_arms``, ``n_variables``,
            ``n_affected``, ``minimum_uplift``, ``inidependent_noise``,
            ``random_seed``, ``save_dir``, ``algo``, ``n_affected_learner``,
            ``minimum_uplift_learner``, ``use_baseline``,
            ``baseline_option``, ``radius``, ``n_runs``, ``n_rounds``,
            ``print_step``, ``save_every`` and ``save_all``.

    Raises:
        ValueError: If ``params.algo`` is not a known algorithm name. This
            is checked before any directory or file is created.
    """
    # Fail fast: without this check an unknown name would only surface after
    # the output directory and the params CSV had already been written.
    if params.algo not in _KNOWN_ALGOS:
        raise ValueError(
            f'unknown algo {params.algo!r}; expected one of '
            + ', '.join(_KNOWN_ALGOS))

    dirname = (f'arms_{params.n_arms}'
               f'-variables_{params.n_variables}'
               f'-affected_{params.n_affected}'
               f'-mlift_{params.minimum_uplift}')

    # NOTE: the attribute name is misspelled, but it matches the ``dest``
    # registered by the command-line parser, so it is kept as is.
    if params.inidependent_noise:
        dirname += '-independent_noise'
        cov_mat_decomp = np.eye(params.n_variables)
    else:
        cov_mat_decomp = None

    dirname += f'-rng_{params.random_seed}'
    # Dots (from float values) are replaced before joining with save_dir so
    # that only the generated part of the path is rewritten.
    dirname = dirname.replace('.', '_')
    dirname = os.path.join(params.save_dir, dirname)

    # exist_ok avoids the check-then-create race when several jobs that
    # share a configuration directory start at the same time.
    os.makedirs(dirname, exist_ok=True)

    bandit = GaussianUpliftBandit(params.n_arms,
                                  params.n_variables,
                                  params.n_affected,
                                  minimum_uplift=params.minimum_uplift,
                                  A=cov_mat_decomp,
                                  rng_initialize=params.random_seed,
                                  rng_covmat=params.random_seed)

    filename = f'{params.algo}'

    if params.algo == 'UpUCB_L':
        filename += f'-{params.n_affected_learner}'
    elif params.algo == 'UpUCB_Gap':
        filename += f'-{params.minimum_uplift_learner}'

    if params.use_baseline:
        filename += '-with_baseline'
        baseline = bandit.baseline
    else:
        filename += '-without_baseline'
        baseline = None

    filename += f'-{params.baseline_option}-radius_{params.radius}'
    filename = filename.replace('.', '_')
    filename += f'-runs_{params.n_runs}'
    filename += f'-rounds_{params.n_rounds}'

    csv_file = os.path.join(dirname, filename + '-params.csv')

    # newline='' lets the csv writer control line endings itself, as the csv
    # module requires; otherwise Windows output gets blank rows.
    with open(csv_file, 'w', newline='') as csvfile:
        writer = csv.writer(csvfile)
        for key, value in vars(params).items():
            writer.writerow([key, value])

    learner = _build_learner(params, bandit, baseline)

    # One task per run; every run gets its own consecutive seed.
    args = [(bandit, learner, params.n_rounds, params.print_step, seed)
            for seed in range(params.random_seed,
                              params.random_seed + params.n_runs)]
    results = p_map(run_para, args)

    regrets = np.cumsum(np.vstack([result[0] for result in results]), axis=1)
    regrets = regrets[:, ::params.save_every]
    arm_his = np.vstack([result[1] for result in results])
    arm_his = arm_his[:, ::params.save_every]

    if params.save_all:
        filename_regrets = os.path.join(dirname, filename + '-regrets.npy')
        filename_arm_his = os.path.join(dirname, filename + '-arm_his.npy')
        np.save(filename_regrets, regrets)
        np.save(filename_arm_his, arm_his)
    else:
        mean_regrets = np.mean(regrets, axis=0)
        std_regrets = np.std(regrets, axis=0)
        filename_mean = os.path.join(dirname, filename + '-regrets-mean.npy')
        filename_std = os.path.join(dirname, filename + '-regrets-std.npy')
        np.save(filename_mean, mean_regrets)
        np.save(filename_std, std_regrets)


def run(bandit, learner, n_rounds, print_step, rng):
    """Execute a single run on private copies of the bandit and learner.

    The caller's ``bandit`` and ``learner`` objects are deep-copied first,
    so the objects passed in are not modified by this function.

    Args:
        bandit: Bandit object; its copy gets ``init_feedback(rng)``.
        learner: Learner object; its copy gets ``set_rng(rng)``.
        n_rounds: Forwarded to ``interact``.
        print_step: Forwarded to ``interact``.
        rng: Value passed to both ``init_feedback`` and ``set_rng``.

    Returns:
        The ``(regrets, arm_his)`` pair produced by ``interact``.
    """
    fresh_bandit = deepcopy(bandit)
    fresh_bandit.init_feedback(rng)

    fresh_learner = deepcopy(learner)
    fresh_learner.set_rng(rng)

    regret_trace, arm_trace = interact(
        fresh_bandit, fresh_learner, n_rounds, print_step)
    return regret_trace, arm_trace


def run_para(args):
    """Single-argument adapter around ``run`` for use with a parallel map.

    Args:
        args: Sequence of positional arguments, unpacked into ``run``.

    Returns:
        The ``(regrets, arm_his)`` pair returned by ``run``.
    """
    outcome = run(*args)
    regret_trace, arm_trace = outcome
    return regret_trace, arm_trace


if __name__ == '__main__':

    # Command-line entry point: declare the options, parse them, run main().
    # Options are registered in a fixed order because the parsed namespace
    # is later dumped key by key.
    parser = ArgumentParser()

    # Output settings followed by the bandit instance settings.
    for option, kind, default in (
            ('--save_dir', str, 'save/gaussian/'),
            ('--save_every', int, 10),
            ('--n_arms', int, 10),
            ('--n_variables', int, 100),
            ('--n_affected', int, 10),
            ('--minimum_uplift', float, 0.1),
    ):
        parser.add_argument(option, type=kind, default=default)

    # Boolean switches. The 'inidependent_noise' dest spelling is the one
    # main() reads, so it is kept unchanged.
    for option, dest, action in (
            ('--independent_noise', 'inidependent_noise', 'store_true'),
            ('--no_baseline', 'use_baseline', 'store_false'),
            ('--save_mean', 'save_all', 'store_false'),
    ):
        parser.add_argument(option, dest=dest, action=action)
    parser.set_defaults(use_baseline=True, inidependent_noise=False, save_all=True)

    # Learner choice, run sizes, learner-side knowledge, logging and seed.
    # Disabled options kept for reference:
    #   --prior_mean (float, default 0), --prior_var (float, default 100),
    #   --n_cpus (int, default 10)
    for option, kind, default in (
            ('--algo', str, 'UpUCB'),
            ('--baseline_option', str, 'UCB'),
            ('--radius', float, 3),
            ('--n_runs', int, 3),
            ('--n_rounds', int, 1000),
            ('--n_affected_learner', int, 10),
            ('--minimum_uplift_learner', float, 0.1),
            ('--print_step', int, 10000),
            ('--random_seed', int, 3),
    ):
        parser.add_argument(option, type=kind, default=default)

    params = parser.parse_args()

    main(params)
